/*
 * Copyright (C) 2023 Samsung Electronics Co. LTD
 *
 * This software is proprietary of Samsung Electronics.
 * No part of this software, either material or conceptual may be copied or
 * distributed, transmitted, transcribed, stored in a retrieval system or
 * translated into any human or computer language in any form by any means,
 * electronic, mechanical, manual or otherwise, or disclosed to third parties
 * without the express written permission of Samsung Electronics.
 *
 */

#ifndef IMDB_PRECISE_RANDOM_SCAN_RANGES_GEN_H
#define IMDB_PRECISE_RANDOM_SCAN_RANGES_GEN_H

#include "base.h"
#include "random.h"

#include "common/log.h"

#include "tools/datagen/imdb/general/columns_info.h"
#include "tools/datagen/imdb/general/ranges_info.h"

#include "pnmlib/imdb/scan_types.h"

#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>

namespace tools::gen::imdb {

/** @brief Class for generating scan ranges with precise actual selectivity.
 * Algorithm uses binary search by length of scan range for solution search.
 * 1. Scan ranges are generated by standard RandomRangesGen
 * 2. Vector of partial sums of entry counts is calculated (see description of
 * partial_count_sum_)
 * 3. Each range is slightly modified by a binary search so that it matches
 * the required selectivity:
 * 3.1. Range is constructed by length
 * 3.2. Actual selectivity is calculated on that range using precalculated
 * partial sums
 * 3.3. If selectivity is equal to required selectivity with precision -
 * optimal range is found. Otherwise step of binary search is performed
 */
class PreciseRandomRangesGen : public IRangesGenerator {

public:
  /** @brief Creates generator bound to a table column.
   *
   * @param column column whose entries are used to measure actual selectivity
   * @param seed seed passed to the underlying RandomRangesGen
   * @param selectivity_epsilon allowed absolute deviation of actual
   * selectivity from the required one (absolute value of the argument is used)
   */
  PreciseRandomRangesGen(std::shared_ptr<const compressed_vector> column,
                         size_t seed = std::random_device{}(),
                         double selectivity_epsilon = 0.001)
      : column_(std::move(column)), seed_(seed),
        selectivity_epsilon_{std::fabs(selectivity_epsilon)} {
    pnm::log::info("PreciseRandomRangesGen created with seed {}", seed_);
  }

private:
  /** @brief Generates random scan ranges and adjusts each of them so that its
   * actual selectivity on the column matches the requested one.
   *
   * @param scan_info description of requested ranges (count, selectivities)
   * @param column_info description of the column (entry bit size)
   * @return adjusted scan ranges
   * @throws std::runtime_error if no column was passed to the constructor
   */
  Ranges create_scan_ranges_impl(const RangesInfo &scan_info,
                                 const ColumnsInfo &column_info) override {

    if (!column_) {
      throw std::runtime_error("Table column must be specified");
    }
    assert(column_->value_bits() == column_info.entry_bit_size());

    auto generator = RandomRangesGen(seed_);
    Ranges scan_ranges = generator.create(scan_info, column_info);
    assert(scan_ranges.size() == scan_info.num_requests());
    assert(scan_ranges.size() == scan_info.selectivities().size());
    const compressed_element_type entry_max_value =
        (1ULL << column_info.entry_bit_size()) - 1;

    // fill partial sums of count of entries
    {
      // Widen before adding 1 so the element count cannot wrap around in the
      // (possibly narrower) entry type.
      partial_count_sum_.assign(static_cast<size_t>(entry_max_value) + 1, 0);
      for (const auto value : *column_) {
        ++partial_count_sum_[value];
      }

      // Turn per-value counts into inclusive prefix sums in place.
      size_t count = 0;
      for (auto &cur_partial_sum : partial_count_sum_) {
        count += cur_partial_sum;
        cur_partial_sum = count;
      }
    }

    // Fetched once: the list of selectivities does not change in the loop.
    const auto &selectivities = scan_info.selectivities();
    for (size_t i = 0; i < scan_ranges.size(); ++i) {
      auto &range = scan_ranges[i];
      assert(range.start <= range.end);
      assert(range.end <= entry_max_value);
      // Only the start of the random range is kept as a hint; the length is
      // recomputed by the binary search.
      range = optimal_range(selectivities[i], 0, entry_max_value, range.start);
    }

    return scan_ranges;
  }

  /** @brief Function for performing binary search.
   * Additional bool flag (solution_found) in result of compare_func is true
   * when selectivity is found with required precision.
   *
   * @param bottom minimal value of length scan range
   * @param top maximal value of length of scan range; must be smaller than
   * the maximal value representable by compressed_element_type
   * @param compare_func - function takes as argument length of scan range and
   * returns pair<bool, bool> - first is true if solution found second is true
   * if selectivity for given length is bigger than target selectivity
   * @result optimal length of scan range, never bigger than top
   */
  static compressed_element_type find_solution(
      compressed_element_type bottom, compressed_element_type top,
      const std::function<std::pair<bool, bool>(compressed_element_type)>
          &compare_func) {
    const compressed_element_type longest_length = top;
    // Exclusive upper bound, so that `longest_length` itself is probed too.
    top++;
    while (bottom < top) {
      // overflow safe (bottom + top)/ 2
      const compressed_element_type mid =
          bottom / 2 + top / 2 + (bottom % 2 + top % 2) / 2;
      const auto [solution_found, is_more] = compare_func(mid);

      if (solution_found) {
        return mid;
      }

      if (is_more) {
        top = mid;
      } else {
        bottom = mid + 1;
      }
    }
    // When even the longest range is not selective enough (or selectivity is
    // NaN) the search ends one past the allowed maximum. Returning that value
    // would make the caller build a range with a wrapped-around start, so the
    // longest allowed length is returned instead.
    return bottom > longest_length ? longest_length : bottom;
  }

  /** @brief This function of constructing scan_ranges by length guarantees
   * monotonicity of the selectivity function of length which is necessary for
   * binary search. Though the disadvantage is that when length is too big
   * range.end will be equal to max_value that doesn't satisfy the condition of
   * randomness of scan_ranges.
   *
   * @param max_value maximal value of end of scan range
   * @param length - length of scan range, must not exceed max_value
   * @param range_start - preferable scan range start, must not exceed
   * max_value
   * @return scan range
   */
  static RangeOperation calc_next_range(compressed_element_type max_value,
                                        compressed_element_type length,
                                        compressed_element_type range_start) {
    // Both subtractions below are unsigned; violating these preconditions
    // would silently wrap around.
    assert(range_start <= max_value);
    assert(length <= max_value);

    RangeOperation range;

    if (length > max_value - range_start) {
      // Range does not fit after the preferred start: pin it to the right
      // border and move the start to the left instead.
      range.end = max_value;
      range.start = max_value - length;
    } else {
      range.start = range_start;
      range.end = range_start + length;
    }

    return range;
  }

  /** @brief Calculates range that responds to required selectivity with
   * precision using binary search. Search is performed by length of scan range.
   * Scan range are constructed in such way that increasing of length leads to
   * increasing of selectivity on this range. And vice versa - reducing length
   * leads to reducing selectivity.
   *
   * @param selectivity required selectivity
   * @param min_value minimal value of start of scan range
   * @param max_value maximal value of end of scan range
   * @param init_range_start preferable start of scan range
   * @return optimal scan range
   */
  RangeOperation optimal_range(double selectivity,
                               compressed_element_type min_value,
                               compressed_element_type max_value,
                               compressed_element_type init_range_start) const {
    const compressed_element_type min_length = 0;
    const compressed_element_type max_length = max_value - min_value;

    // Binary search is performed by scan_range length.
    const auto compare_func =
        [&](compressed_element_type length) -> std::pair<bool, bool> {
      const auto range = calc_next_range(max_value, length, init_range_start);
      const auto actual_selectivity = calc_actual_selectivity(range);
      const bool solution_found =
          std::fabs(actual_selectivity - selectivity) <= selectivity_epsilon_;
      const bool is_more = actual_selectivity > selectivity;
      return std::make_pair(solution_found, is_more);
    };

    const auto length = find_solution(min_length, max_length, compare_func);
    return calc_next_range(max_value, length, init_range_start);
  }

  /** @brief Function for fast actual selectivity calculation because
   * partial sums are already precalculated on start
   *
   * @param range scan range (both borders are inclusive)
   * @return actual selectivity, 0 for an empty column
   */
  double calc_actual_selectivity(const RangeOperation &range) const {
    const auto total_entries = column_->size();
    if (total_entries == 0) {
      // Avoid 0/0 == NaN, which would make every comparison in the binary
      // search false.
      return 0.0;
    }

    const auto left_sum =
        range.start == 0 ? 0 : partial_count_sum_[range.start - 1];
    const auto right_sum = partial_count_sum_[range.end];
    return (right_sum - left_sum) / static_cast<double>(total_entries);
  }

  std::shared_ptr<const compressed_vector> column_;
  size_t seed_;
  double selectivity_epsilon_;

  /** @brief vector of partial sums, where index is entry in compressed column,
   * value is number of column entries that are less than or equal to index.
   * For example, column - {1,3,2,3,1}, vector of counts - {0, 2, 1, 2},
   * vector of partial_count_sums - {0, 2, 3, 5} Since column entry <= 262143,
   * so size of vector <= 262144
   */
  std::vector<size_t> partial_count_sum_;
};

} // namespace tools::gen::imdb

#endif // IMDB_PRECISE_RANDOM_SCAN_RANGES_GEN_H
